(* Arbitrary-precision natural number implementation.
 *
 * A non-negative integer is represented as an int64 array in little-endian cell order, using only the lower 32 bits of each cell.
 * int64 cells leave room for intermediate results (sums, products, carries) in 64 bits; the arithmetic base is 2^32.
 *
 * Simplicity matters more than performance here, so multiplication uses the naive O(N^2) algorithm.
 * For reference, faster algorithms typically start to pay off at roughly:
 * 0 to 100,000 bits: naive algorithm
 * 100,000 to 1,000,000 bits: Karatsuba
 * 1,000,000 bits and above: Schönhage–Strassen
 *
 *)

module Bigint =

(* Numbers are stored in base 2^32: every int64 cell holds one 32-bit digit. *)
let base_log = 32                            (* bits per digit, i.e. log2 base *)
let base = Int64.shift_left 1L base_log      (* 2^32 = 0x100000000 *)
let base_not = Int64.sub base 1L             (* 0xffffffff: mask keeping the low 32 bits of a cell *)

(* Short infix names for Int64 arithmetic, because infix notation is easier to read. *)
let ( ++ ) = Int64.add
let ( -- ) = Int64.sub
let ( ** ) = Int64.mul
let ( >> ) = Int64.shift_right

(* zero has no digits at all; one is the single digit 1 *)
let zero = Array.make 0 0L
let one = [| 1L |]

(* Get ith cell or 0 if out of bounds 
   NOTE: calling this within a loop may hurt performance on architectures without good branch prediction *)
let get a i n =
  if i >= n then
    0L
  else
    a[i]
      
(* Sum of two big integers whose lengths may differ.
   Cells are added pairwise from the least significant one, propagating a carry of 0 or 1;
   the result gains one extra cell only when the final carry is set.
   NOTE: restricting addition to equal-length operands would remove the need for "get" *)
let add a b =
  let na = Array.length a in
  let nb = Array.length b in
  let size = max na nb in
  let ret = Array.make size 0L in
  (* fills ret from cell i upwards and returns the carry left after the last cell *)
  let rec go i carry =
    if i >= size then
      carry
    else
      let s = get a i na ++ get b i nb ++ carry in
      let () = ret[i] <- Int64.logand s base_not in   (* low 32 bits are the digit *)
      go (i + 1) (s >> base_log)                      (* bit 32 is the next carry *)
  in
  if go 0 0L = 0L then ret else Array.append ret [| 1L |]

(* Subtraction a - b of natural numbers.
   Requires a >= b: a negative result cannot be represented, so Invalid_argument is raised
   instead of returning a wrapped, meaningless value.
   The result carries no trailing zero cells (zero comes back as the empty array). *)
let sub a b =
  let borrow = ref 0L in
  let na = Array.length a in
  let nb = Array.length b in
  let size = max na nb in
  let actual_size = ref 0 in (* 1 + index of the highest non-zero cell seen so far, used to trim trailing zeros *)
  let ret = Array.make size 0L in
  let () =
   for 0 (size - 1)
    (fun i ->
       (* borrow is 0 or -1, so result lies in [-2^32, 2^32 - 1] *)
       let result = get a i na -- get b i nb ++ !borrow in
       (* two's complement: the low 32 bits are the right digit even when result is negative *)
       let () = ret[i] <- Int64.logand result base_not in
       let () =
        if ret[i] = 0L then   (* a conditional inside the loop; NOTE: may hurt performance on some architectures *)
          ()
        else
          actual_size := i + 1 in
       borrow := if Int64.compare result 0L >= 0 then 0L else -1L) in
  (* a borrow left after the most significant cell means a < b *)
  if !borrow <> 0L then
    invalid_arg "Bigint.sub: negative result"
  else if !actual_size <> size then
    Array.sub ret 0 !actual_size
  else
    ret

(* Multiplication of a big integer by a single 'digit' x, with 0 <= x < base.
   x * a[i] + carry can reach (2^32 - 1)^2 + (2^32 - 1) = 2^64 - 2^32, which does not fit in a
   signed int64: it wraps into the sign bit. The wrapped bit pattern is still the correct unsigned
   64-bit value, so the low digit is the low 32 bits and the carry is the high 32 bits taken with a
   LOGICAL shift. An arithmetic shift would sign-extend and produce a negative carry. *)
let mul_digit a x =
  let carry = ref 0L in
  let size = Array.length a in
  let ret = Array.make size 0L in
  let () =
   for 0 (size - 1)
    (fun i ->
       let result = x ** a[i] ++ !carry in
       let () = ret[i] <- Int64.logand result base_not in
       carry := Int64.shift_right_logical result base_log) in
  if !carry = 0L then
    ret
  else
    Array.append ret [| !carry |]

(* Schoolbook O(len a * len b) multiplication: a * b is the sum over i of (a * b[i]) * base^i.
   Each partial product is formed with mul_digit and moved up i cells before being accumulated. *)
let mul a b =
  let nb = Array.length b in
  (* a * b[i] * base^i: prepend i zero cells to the digit product *)
  let partial i = Array.append (Array.make i 0L) (mul_digit a b[i]) in
  let rec accumulate i acc =
    if i >= nb then
      acc
    else
      accumulate (i + 1) (add acc (partial i))
  in
  if Array.length a = 0 then zero else accumulate 0 zero

(* Three-way comparison of two big integers: negative, zero or positive, as with Int64.compare.
   Trailing zero cells are allowed: missing cells of the shorter operand read as 0,
   so [|5L|] and [|5L; 0L|] compare equal. Cells are scanned from the most significant down. *)
let compare a b =
  let na = Array.length a in
  let nb = Array.length b in
  (* compares cells i, i-1, ..., 0 and stops at the first difference *)
  let rec scan i =
    if i < 0 then
      0
    else
      let c = Int64.compare (get a i na) (get b i nb) in
      if c = 0 then scan (i - 1) else c
  in
  scan (max na nb - 1)

(* [get_bit x i] is bit number i of x (bit 0 is the least significant), as 0L or 1L.
   Bits above the stored cells read as 0, since a natural number has only zero bits there
   (previously such an index raised an out-of-bounds error). Requires i >= 0. *)
let get_bit x i =
  let cell = i / base_log in
  let bit = i mod base_log in
  Int64.logand (get x cell (Array.length x) >> bit) 1L

(* Shift a big integer left by n bits, i.e. multiply it by 2^n (requires n >= 0).
   Whole cells move up by n / base_log positions; the remaining n mod base_log bits are shifted
   inside each cell, and the bits pushed past bit 31 are carried into the next cell.
   Since a[i] < 2^32 and bits_shift < 32, a[i] shifted plus the carry stays below 2^63: no overflow.
   Shifting the empty array (zero) returns zero itself rather than a run of zero cells,
   so the result never has trailing zero cells (actualsize and divrem rely on that). *)
let shift_left a n =
  let size = Array.length a in
  if size = 0 then
    zero
  else
    let carry = ref 0L in
    let digits_shift = n / base_log in
    let bits_shift = n mod base_log in
    let ret = Array.make (size + digits_shift) 0L in
    let () =
     for 0 (size - 1)
      (fun i ->
         let result = Int64.shift_left a[i] bits_shift ++ !carry in
         let () = ret[i + digits_shift] <- Int64.logand result base_not in
         carry := result >> base_log) in
    if !carry = 0L then
      ret
    else
      Array.append ret [| !carry |]
  
(* Number of significant bits of a big int (its bit length), example: size of 000100110 is 6.
   Zero has size 0. Trailing zero cells are skipped; the previous version looped forever when the
   last cell was 0.
   This is only used by divrem. *)
let actualsize n =
  (* bit length of one cell value (0 for 0L); cells hold at most 32 bits, so this takes at most 32 steps *)
  let rec bits_of_cell x accu =
    if x = 0L then accu else bits_of_cell (Int64.shift_right_logical x 1) (accu + 1)
  in
  (* locate the most significant non-zero cell, then count the bits below it *)
  let rec top i =
    if i < 0 then
      0
    else if n[i] = 0L then
      top (i - 1)
    else
      bits_of_cell n[i] 0 + 32 * i
  in
  top (Array.length n - 1)
      
(* Euclidean division of natural numbers: divrem n d = (q, r) with n = q * d + r and r < d.
   Binary long division: each step subtracts from the current remainder the largest shifted copy
   d * 2^k that does not exceed it, and adds 2^k to the quotient.
   The loop is tail-recursive, so the stack does not grow with the number of quotient bits.
   Raises Division_by_zero when d is zero; without this check the loop never terminates.
   Assumes n and d have no trailing zero cells. *)
let divrem n d =
  if compare d zero = 0 then
    raise Division_by_zero
  else
    (* q: quotient accumulated so far; r: what is left of n to divide *)
    let rec loop q r =
      if compare r d < 0 then
        (q, r)
      else
        let difference_in_size = actualsize r - actualsize d in
        (* how much d must be shifted to be "just" not greater than r *)
        let shift_length =
          difference_in_size - (if compare r (shift_left d difference_in_size) >= 0 then 0 else 1) in
        loop (add q (shift_left one shift_length)) (sub r (shift_left d shift_length))
    in
    loop zero n


(* Quotient of the Euclidean division of a by b: first component of divrem. *)
let div a b =
  let (quotient, _) = divrem a b in
  quotient

(* Remainder of the Euclidean division of a by b: second component of divrem. *)
let rem a b =
  let (_, remainder) = divrem a b in
  remainder

  
(* Find the multiplicative inverse of a in Z/pZ, where p is the second argument.
 * https://en.wikipedia.org/wiki/Extended_Euclidean_algorithm#Modular_integers
 *
 * Our goal is to calculate t such that: t.a ≡ 1 [p]
 * The extended Euclidean algorithm gives t as the last element of the sequence: tᵢ₊₁ = tᵢ₋₁ - qᵢtᵢ
 * Reduction modulo p is compatible with addition, subtraction and multiplication, so it is correct
 * to compute the (tᵢ) modulo p. This matters because there are no negative big integers: every tᵢ
 * is kept in [0, p) and the subtraction is performed modulo p (see do_sub).
 * An inverse exists only when gcd(a, p) = 1; otherwise Invalid_argument is raised. *)
let inverse a p =
  let ( ++ ) = add in
  let ( -- ) = sub in
  let ( ** ) = mul in
  (* (x - y) mod p for x in [0, p) and any y; the result stays in [0, p) and
     no call to sub ever sees a negative difference *)
  let do_sub x y =
    if compare x y >= 0 then
      x -- y
    else if compare y p <= 0 then
      x ++ p -- y
    else
      let yrem = rem y p in
      if compare x yrem >= 0 then
        x -- yrem
      else
        x ++ (p -- yrem)
  in
  let t = ref zero in
  let newt = ref one in
  let r = ref p in
  let newr = ref a in
  let () =
   while (fun () -> compare !newr zero > 0)
    (fun () ->
       let quotient, remainder = divrem !r !newr in
       (* computed before any update, since it needs the current t and newt *)
       let next_t = do_sub !t (quotient ** !newt) in
       let () = r := !newr in
       let () = newr := remainder in
       let () = t := !newt in
       newt := next_t) in
  (* r now holds gcd(a, p): a is invertible modulo p only when it is 1 *)
  if compare !r one <> 0 then
    invalid_arg "Bigint.inverse: argument is not invertible modulo p"
  else
    !t


endmodule
